"""
ESMFold model with continuous linker placed at the embedding layer of folding trunk

esmfold_v1: esm + folding_trunk + structure_module
folding_trunk
    Embedding(23, 1024, padding_idx=0)
    n_tokens_embed = 23
    pad_idx = 0
    unk_idx = n_tokens_embed - 2
    mask_idx = n_tokens_embed - 1

distogram_head: Linear(in_features=128, out_features=64, bias=True)
"""

import os
import typing as T
from sklearn.metrics.pairwise import cosine_similarity

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from openfold.data.data_transforms import make_atom14_masks
from openfold.np import residue_constants
from openfold.utils.loss import compute_predicted_aligned_error, compute_tm

from esm.esmfold.v1.categorical_mixture import categorical_lddt
from esm.esmfold.v1.trunk import FoldingTrunk, FoldingTrunkConfig
from esm.esmfold.v1.misc import (
    batch_encode_sequences,
    collate_dense_tensors,
    output_to_pdb,
)
from esm.pretrained import load_model_and_alphabet_local
from esm.esmfold.v1.esmfold import ESMFold

from utils.logger import Logger
logger = Logger.logger


def softmax_cross_entropy(logits, labels):
    """Return the cross entropy between soft/one-hot ``labels`` and ``logits``.

    The softmax is taken over the last dimension of ``logits`` and that
    dimension is summed out, so the result has the shape of ``logits``
    without its last axis.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    weighted = labels * log_probs
    return -1 * weighted.sum(dim=-1)


class ESMFold_ContinuousLinker_v1(ESMFold):
    """ESMFold whose inter-chain linker is a learnable (continuous) embedding.

    The module builds the layers used by ``forward`` (ESM language model,
    folding trunk, distogram / pTM / LM / lDDT heads) and adds
    ``linker_embedding``: ``linker_len`` trainable vectors that overwrite the
    folding-trunk residue embedding at every linker position between chains.
    """

    def __init__(self, linker_len, esm_location, esmfold_config=None, **kwargs):
        """Build the model.

        Args:
            linker_len: number of linker positions inserted between two chains.
            esm_location: local path of the ESM language-model checkpoint.
            esmfold_config: ESMFold config; when falsy an ``ESMFoldConfig`` is
                built from ``kwargs``.
        """
        # nn.Module.__init__ is called directly (not ESMFold.__init__) because
        # the language model is loaded from a local file below.
        nn.Module.__init__(self)
        if esmfold_config:
            self.cfg = esmfold_config
        else:
            # Imported here because the module header only imports ESMFold.
            from esm.esmfold.v1.esmfold import ESMFoldConfig

            self.cfg = ESMFoldConfig(**kwargs)
        cfg = self.cfg

        self.distogram_bins = 64

        self.esm, self.esm_dict = load_model_and_alphabet_local(esm_location)

        # The language model is frozen and kept in half precision.
        self.esm.requires_grad_(False)
        self.esm.half()

        self.esm_feats = self.esm.embed_dim
        self.esm_attns = self.esm.num_layers * self.esm.attention_heads
        self.register_buffer("af2_to_esm", ESMFold._af2_to_esm(self.esm_dict))
        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm.num_layers + 1))

        c_s = cfg.trunk.sequence_state_dim
        c_z = cfg.trunk.pairwise_state_dim

        self.esm_s_mlp = nn.Sequential(
            LayerNorm(self.esm_feats),
            nn.Linear(self.esm_feats, c_s),
            nn.ReLU(),
            nn.Linear(c_s, c_s),
        )

        # 0 is padding, N is unknown residues, N + 1 is mask.
        self.n_tokens_embed = residue_constants.restype_num + 3
        self.pad_idx = 0
        self.unk_idx = self.n_tokens_embed - 2
        self.mask_idx = self.n_tokens_embed - 1
        self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

        self.trunk = FoldingTrunk(**cfg.trunk)

        self.distogram_head = nn.Linear(c_z, self.distogram_bins)
        self.ptm_head = nn.Linear(c_z, self.distogram_bins)
        self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
        self.lddt_bins = 50
        self.lddt_head = nn.Sequential(
            nn.LayerNorm(cfg.trunk.structure_module.c_s),
            nn.Linear(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim),
            nn.Linear(cfg.lddt_head_hid_dim, cfg.lddt_head_hid_dim),
            nn.Linear(cfg.lddt_head_hid_dim, 37 * self.lddt_bins),
        )

        # Learnable linker: one vector per linker position.
        self.linker_embedding = nn.Embedding(linker_len, c_s)
        self.linker_len = linker_len
        # Plain tensor (not a buffer); forward() moves it to the right device.
        self.linker_tokens = torch.arange(self.linker_len)

    def forward(
        self,
        aa: torch.Tensor,
        lengths: T.Optional[list] = None,
        s_s_0: T.Optional[torch.Tensor] = None,
        mask: T.Optional[torch.Tensor] = None,
        residx: T.Optional[torch.Tensor] = None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        num_recycles: T.Optional[int] = None,
    ):
        """Runs a forward pass given input tokens. Use `model.predict` to
        run inference from a sequence.

        Args:
            aa (torch.Tensor): Tensor containing indices corresponding to amino acids. Indices match
                openfold.np.residue_constants.restype_order_with_x. The linker positions
                must already be present in `aa`.
            lengths: per-chain lengths of the (single) complex, e.g. [len1, len2]. A linker of
                `linker_len` positions is expected after every chain except the last.
                If None or a single chain, no linker embedding is injected.
            s_s_0: precomputed sequence representation (B, seq_len, c_s). If given, the ESM
                language model is skipped. The tensor is not modified.
            mask (torch.Tensor): Binary tensor with 1 meaning position is unmasked and 0 meaning position is masked.
            residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
            masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
                as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
                different masks are provided.
            num_recycles (int): How many recycle iterations to perform. Forwarded unchanged to the
                folding trunk (None presumably selects the trunk's default - TODO confirm).

        Returns:
            dict with the trunk outputs and "distogram_logits". When the distogram logits do
            not require grad (inference), the LM logits, atom masks, pLDDT, pTM and
            predicted-aligned-error entries are added as well.

        Note:
            The linker embedding is written into batch element 0 only, i.e. batch size 1
            is assumed.
        """

        if mask is None:
            mask = torch.ones_like(aa)

        B = aa.shape[0]
        L = aa.shape[1]
        device = aa.device

        if residx is None:
            residx = torch.arange(L, device=device).expand_as(aa)

        # === ESM ===
        if s_s_0 is None:

            esmaa = self._af2_idx_to_esm_idx(aa, mask)

            if masking_pattern is not None:
                esmaa = self._mask_inputs_to_esm(esmaa, masking_pattern)

            esm_s = self._compute_language_model_representations(esmaa)

            # Convert esm_s to the precision used by the trunk and
            # the structure module. These tensors may be a lower precision if, for example,
            # we're running the language model in fp16 precision.
            esm_s = esm_s.to(self.esm_s_combine.dtype)

            esm_s = esm_s.detach()

            # === preprocessing ===
            # Softmax-weighted combination of the language-model layers.
            esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)

            s_s_0 = self.esm_s_mlp(esm_s)    # seq representation

        s_z_0 = s_s_0.new_zeros(B, L, L, self.cfg.trunk.pairwise_state_dim)

        # === add linker embedding ===
        struc_embed = self.embedding(aa)    # (B, L, c_s)
        if lengths is not None and len(lengths) > 1:
            linker_embed = self.linker_embedding(self.linker_tokens.to(s_s_0.device))
            start = 0
            for chain_len in lengths[:-1]:
                start += chain_len
                struc_embed[0, start: start + self.linker_len] = linker_embed
                # Skip the linker just written so the next chain (and the next
                # linker, for 3+ chains) starts at the right offset.
                start += self.linker_len

        # Out-of-place add: a caller-provided s_s_0 (e.g. a cached ESM
        # representation) must not be modified. Cast keeps s_s_0's dtype.
        s_s_0 = s_s_0 + struc_embed.to(s_s_0.dtype)

        structure: dict = self.trunk(s_s_0, s_z_0, aa, residx, mask, no_recycles=num_recycles)
        # Documenting what we expect:
        expected_keys = (
            "s_z",
            "s_s",
            "frames",
            "sidechain_frames",
            "unnormalized_angles",
            "angles",
            "positions",
            "states",
        )
        structure = {k: v for k, v in structure.items() if k in expected_keys}

        disto_logits = self.distogram_head(structure["s_z"])
        # Symmetrize over the two residue axes.
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
        structure["distogram_logits"] = disto_logits

        # Confidence heads and atom masks are only computed when no gradient
        # is tracked on the distogram logits.
        if not disto_logits.requires_grad:

            lm_logits = self.lm_head(structure["s_s"])
            structure["lm_logits"] = lm_logits

            structure["aatype"] = aa
            make_atom14_masks(structure)

            for k in [
                "atom14_atom_exists",
                "atom37_atom_exists",
            ]:
                structure[k] *= mask.unsqueeze(-1)
            structure["residue_index"] = residx

            lddt_head = self.lddt_head(structure["states"]).reshape(
                structure["states"].shape[0], B, L, -1, self.lddt_bins
            )
            structure["lddt_head"] = lddt_head
            plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
            structure["plddt"] = (
                100 * plddt
            )  # we predict plDDT between 0 and 1, scale to be between 0 and 100.

            ptm_logits = self.ptm_head(structure["s_z"])

            seqlen = mask.type(torch.int64).sum(1)
            structure["ptm_logits"] = ptm_logits
            structure["ptm"] = torch.stack(
                [
                    compute_tm(
                        batch_ptm_logits[None, :sl, :sl], max_bins=31, no_bins=self.distogram_bins
                    )
                    for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
                ]
            )
            structure.update(
                compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
            )

        return structure

    def predict(
        self,
        sequences: T.Union[str, T.List[str]],
        lengths: T.List,
        s_s_0: T.Optional[torch.Tensor] = None,
        residx=None,
        masking_pattern: T.Optional[torch.Tensor] = None,
        num_recycles: T.Optional[int] = None,
        residue_index_offset: T.Optional[int] = 512,
        chain_linker: T.Optional[str] = "G" * 25,
    ):
        """Runs a forward pass given input sequences.

        Args:
            sequences (Union[str, List[str]]): A list of sequences to make predictions for. Multimers can also be passed in,
                each chain should be separated by a ':' token (e.g. "<chain1>:<chain2>:<chain3>").
            lengths (List): per-chain lengths, forwarded to `forward` to locate the linker positions.
            s_s_0 (torch.Tensor): optional precomputed sequence representation, forwarded to `forward`.
                Its sequence dimension must match the encoded sequence (linker included).
            residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
            masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
                as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
                different masks are provided.
            num_recycles (int): How many recycle iterations to perform. Forwarded unchanged to `forward`.
            residue_index_offset (int): Residue index separation between chains if predicting a multimer. Has no effect on
                single chain predictions. Default: 512.
            chain_linker (str): Linker to use between chains if predicting a multimer. Has no effect on single chain
                predictions. Default: length-25 poly-G ("G" * 25).

        Raises:
            ValueError: if `s_s_0` does not have the same sequence length as the encoded input.
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        aatype, mask, _residx, linker_mask, chain_index = batch_encode_sequences(
            sequences, residue_index_offset, chain_linker
        )

        if residx is None:
            residx = _residx
        elif not isinstance(residx, torch.Tensor):
            residx = collate_dense_tensors(residx)

        aatype, mask, residx, linker_mask = map(
            lambda x: x.to(self.device), (aatype, mask, residx, linker_mask)
        )

        if s_s_0 is not None:
            if aatype.shape[1] != s_s_0.shape[1]:
                raise ValueError('sequence length of s_s_0 does not match aatype')
            s_s_0 = s_s_0.to(self.device)

        output = self.forward(
            aatype,
            lengths=lengths,
            s_s_0=s_s_0,
            mask=mask,
            residx=residx,
            masking_pattern=masking_pattern,
            num_recycles=num_recycles,
        )

        # The confidence entries only exist when forward() ran without grad.
        if not output['distogram_logits'].requires_grad:
            # Drop linker atoms so they do not count towards the mean pLDDT.
            output["atom37_atom_exists"] = output["atom37_atom_exists"] * linker_mask.unsqueeze(2)

            output["mean_plddt"] = (output["plddt"] * output["atom37_atom_exists"]).sum(
                dim=(1, 2)
            ) / output["atom37_atom_exists"].sum(dim=(1, 2))
            output["chain_index"] = chain_index

        return output

    def output_to_pdb(self, output: T.Dict) -> T.List[str]:
        """Returns the PDB (file) strings for the given model output."""
        return output_to_pdb(output)

    def initialize_linker_embed(self):
        """Initialize every linker vector with the trunk embedding of glycine ("G")."""
        # forward() looks rows up with the raw restype index (self.embedding(aa)),
        # so glycine lives in row restype_order_with_x['G'] - no +1 offset.
        g_idx = residue_constants.restype_order_with_x['G']
        with torch.no_grad():
            g_embed = self.embedding.weight[g_idx]
            self.linker_embedding.weight.copy_(g_embed.expand(self.linker_len, -1))




def load_pretrained_esmfold(linker_len, model_name='esmfold_3B_v1', model_dir='checkpoint/esmfold'):
    """Build an ``ESMFold_ContinuousLinker_v1`` and load ESMFold weights into it.

    Args:
        linker_len: number of linker positions of the model to build.
        model_name: a name ending in ".pt" is read from ``model_dir``;
            any other name is downloaded from the fair-esm model URL.
        model_dir: directory holding the ESMFold checkpoint (for ".pt" names)
            and the "esm2_t36_3B_UR50D.pt" language-model checkpoint.

    Returns:
        The model with the checkpoint weights loaded non-strictly. Keys the
        model expects but the checkpoint lacks are logged, except "esm." ones.
    """
    if not model_name.endswith(".pt"):
        # Not a file name: fetch the checkpoint from the hub.
        url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
        checkpoint = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    else:
        # Local checkpoint file inside model_dir.
        checkpoint_path = os.path.join(model_dir, model_name)
        checkpoint = torch.load(str(checkpoint_path), map_location="cpu")

    model_cfg = checkpoint["cfg"]["model"]
    weights = checkpoint["model"]

    model = ESMFold_ContinuousLinker_v1(
        linker_len,
        os.path.join(model_dir, "esm2_t36_3B_UR50D.pt"),
        esmfold_config=model_cfg,
    )

    # The language model is loaded separately, so missing "esm." keys are fine.
    absent = set(model.state_dict().keys()) - set(weights.keys())
    missing_essential_keys = [key for key in absent if not key.startswith("esm.")]
    if missing_essential_keys:
        logger.info(f"Keys '{', '.join(missing_essential_keys)}' are missing.")

    model.load_state_dict(weights, strict=False)
    return model


class ContinousLinkerModel(nn.Module):
    """Training / inference wrapper around the continuous-linker ESMFold model.

    It predicts the distogram of a complex whose chains are joined by the
    learnable linker, optionally strips the linker from the outputs and
    computes the distogram loss.
    """

    def __init__(self, args):
        """Load the backbone selected by ``args.backbone``.

        Reads ``args.backbone``, ``args.linker_len``, ``args.backbone_dir`` and
        ``args.precompute_esm2`` here; other fields of ``args`` are read later.

        Raises:
            ValueError: if ``args.backbone`` is not 'esmfold_v1'.
        """
        super().__init__()
        if args.backbone != 'esmfold_v1':
            # Without this guard self.esmfold would be undefined below.
            raise ValueError('unsupported backbone: {}'.format(args.backbone))
        self.esmfold = load_pretrained_esmfold(args.linker_len,
                                               model_name='esmfold_3B_v1.pt',
                                               model_dir=args.backbone_dir)
        logger.info('{} successfully loaded'.format(args.backbone))
        self.esmfold = self.esmfold.cuda()
        if args.precompute_esm2:
            self.esmfold.esm.cpu()   # put esm in CPU to save cuda memory
        self.batch_converter = self.esmfold.esm.alphabet.get_batch_converter()
        self.args = args
        self.linker_len = args.linker_len

    def forward(self,
                dimer_seq: str,
                seq_rep: T.Optional[torch.Tensor] = None,
                labels: T.Optional[torch.Tensor] = None,
                lengths: T.Optional[T.List] = None,
                ):
        """
        Args:
            dimer_seq: "chain1:chain2" (more chains may be given, but the loss
                supports exactly two).
            seq_rep: precomputed esm2 sequence representation of dimer_seq
                including the linker positions; passed to predict as s_s_0.
            labels: distogram bin labels, see `distogram_loss`.
            lengths: [len1, len2]
        Returns:
            (disto_logits, disto_loss, predicted_pdb); the loss is None without
            labels or when the linker is kept, the pdb string is None unless
            requested in test mode.
        Raises:
            ValueError: if the linker must be removed but `lengths` is None.
        Note: only work for batch_size=1
        """
        # The linker is removed in train mode or when the user does not ask to keep it.
        strip_linker = (self.args.model_mode == "train"
                        or self.args.keep_linker_in_output == False)  # noqa: E712
        if strip_linker and lengths is None:
            raise ValueError('lengths is required to remove the linker from the output')

        # Placeholder residues; their trunk embedding is replaced by the learnable linker.
        chain_linker = self.linker_len * 'G'
        # distogram_logits: (1, seqlen+linker_len, seqlen+linker_len, 64)
        output = self.esmfold.predict(sequences=dimer_seq,
                                      lengths=lengths,
                                      s_s_0=seq_rep,
                                      num_recycles=self.args.num_recycles,
                                      residue_index_offset=self.args.residue_gap,
                                      chain_linker=chain_linker
                                      )
        disto_loss = None
        predicted_pdb = None
        if strip_linker:
            output = self.remove_linker_from_output(output, lengths)
            disto_logits = output['distogram_logits']
            if labels is not None:
                disto_loss = self.distogram_loss(disto_logits, labels, lengths)
            if self.args.model_mode == 'test' and self.args.output_pdb:
                predicted_pdb = self.esmfold.output_to_pdb(output)[0]
                assert isinstance(predicted_pdb, str), 'predicted_pdb should be str'
        else:
            disto_logits = output['distogram_logits']

        return disto_logits, disto_loss, predicted_pdb

    @staticmethod
    def _upper_triangle_mean(block):
        """Mean of the upper triangle (diagonal included) of a square 2-D tensor."""
        n = block.shape[-1]
        # Index with separate row / column vectors: a single 2xK index tensor
        # would select whole rows instead of (row, col) pairs.
        rows, cols = torch.triu_indices(n, n, device=block.device)
        return block[rows, cols].mean()

    def distogram_loss(self, logits, labels, lengths):
        """
        chain A distogram loss + chain B distogram loss + inter_weight * inter-chain distogram loss

        logits: (1, seqlen, seqlen, num_bins)
        labels: (1, seqlen, seqlen), bin labels starting at 1 (1 is subtracted
            before use as class index)
        lengths: [len1, len2] with len1 + len2 == seqlen

        The intra-chain terms average the upper triangle (diagonal included) of
        each chain block; the inter-chain term averages the chain1 x chain2 block.

        Raises:
            ValueError: on a shape mismatch or a batch size other than 1.
        """
        if labels.dtype != torch.int64:
            labels = labels.type(torch.int64)
        num_bins = self.esmfold.distogram_bins
        if tuple(logits.shape) != (*labels.shape, num_bins):
            raise ValueError("labels shape does not match logits shape")
        if logits.dim() != 4 or logits.shape[0] != 1:
            raise ValueError("distogram_loss only supports batch_size=1")

        # Class-index cross entropy: same value as the one-hot formulation
        # without materialising a (1, seqlen, seqlen, num_bins) int64 tensor.
        loss = F.cross_entropy(
            logits.reshape(-1, num_bins),
            (labels - 1).reshape(-1),
            reduction="none",
        ).reshape(labels.shape)
        loss = loss[0]  # (seqlen, seqlen), symmetric matrix

        len1, len2 = lengths
        chain1_loss = self._upper_triangle_mean(loss[:len1, :len1])
        chain2_loss = self._upper_triangle_mean(loss[len1:, len1:])
        inter_chain_loss = torch.mean(loss[:len1, len1:])
        return chain1_loss + chain2_loss + self.args.inter_weight * inter_chain_loss

    def freeze_esmfold(self):
        """Freeze every ESMFold parameter except the linker embedding."""
        self.esmfold.requires_grad_(False)
        self.esmfold.linker_embedding.weight.requires_grad = True

    def convert_linker_to_aa_seq(self):
        """Return the amino-acid string nearest (cosine similarity) to the linker.

        Each linker vector is compared with the trunk embedding rows of the
        residue types in ``residue_constants.restype_order_with_x`` and mapped
        to the most similar one.
        """
        linker_embed = self.esmfold.linker_embedding.weight.detach().cpu().numpy()
        # ESMFold_ContinuousLinker_v1.forward looks rows up with the raw restype
        # index (self.embedding(aa)), so row k belongs to restype k. Keeping the
        # first len(restype_order_with_x) rows drops the unused unk/mask rows.
        num_restypes = len(residue_constants.restype_order_with_x)
        fold_embed = self.esmfold.embedding.weight.detach().cpu().numpy()[:num_restypes]

        sim = cosine_similarity(linker_embed, fold_embed)
        logger.info('max similarity: {}'.format(sim.max(1)))
        nearest_aa = list(sim.argmax(1))

        residx_to_restype = {idx: restype for (restype, idx) in residue_constants.restype_order_with_x.items()}
        aa_seq = ''.join([residx_to_restype[x] for x in nearest_aa])
        logger.info('nearest discrete linker: {}'.format(aa_seq))
        return aa_seq

    def remove_linker_from_output(self, output, lengths):
        """Drop the linker positions from the per-residue entries of ``output``.

        ``lengths`` holds the chain lengths; a linker of ``self.linker_len``
        positions is assumed after every chain. ``output`` is updated in place
        and returned. The structure entries are only trimmed when the distogram
        logits do not require grad (they are only produced in that case).
        """
        keep_indexs = []
        start = 0
        for chain_len in lengths:
            keep_indexs.extend(range(start, start + chain_len))
            start += chain_len + self.linker_len

        # Trim both residue axes of the distogram.
        disto = output['distogram_logits'][:, keep_indexs, :, :]
        output['distogram_logits'] = disto[:, :, keep_indexs, :]

        if not output['distogram_logits'].requires_grad:
            output['positions'] = output['positions'][:, :, keep_indexs]
            # (1, len1+linker_len+len2, 37) before trimming, 0 in the linker part
            for key in ("residx_atom37_to_atom14", 'aatype', "atom37_atom_exists",
                        'residue_index', "plddt", 'chain_index'):
                output[key] = output[key][:, keep_indexs]
        return output
      